{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "CachedDataset",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wu3njRHLcIcr",
        "colab_type": "text"
      },
      "source": [
        "# Using CachedDataset in Resource-Constrained Environments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GKMmNL09cakQ",
        "colab_type": "text"
      },
      "source": [
        "Training runs for an ML model often use exactly the same dataset, with exactly the same transformations applied to the raw data.\n",
        "\n",
        "When those transformations require a considerable amount of CPU and/or RAM, and the environment is short on those resources, a *CachedDataset* lets you trade CPU/RAM for storage/network.\n",
        "\n",
        "A *CachedDataset* wraps any existing *PyTorch* *Dataset* and transparently caches the training samples, so that once the dataset is fully cached, no further CPU/RAM is spent processing it.\n",
        "\n",
        "A *CachedDataset* can also be useful when enough CPU/RAM is available: if the raw data processing performed by the input pipeline is heavy, loading the already processed data from storage is still beneficial.\n"
      ]
    },
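    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The caching idea described above can be illustrated with a minimal, self-contained sketch in plain Python. This is not the `torch_xla` implementation: `SimpleCachedDataset` and `SquaresDataset` are hypothetical names, and `pickle` is used as a stand-in serializer, purely to show how a wrapper can compute each sample once and read it back from storage afterwards. On the second pass the source dataset is never asked to compute anything:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import pickle\n",
        "import tempfile\n",
        "\n",
        "\n",
        "class SimpleCachedDataset:\n",
        "  # Toy illustration of the caching idea; NOT the torch_xla implementation.\n",
        "\n",
        "  def __init__(self, dataset, cache_dir):\n",
        "    self._dataset = dataset\n",
        "    self._cache_dir = cache_dir\n",
        "    os.makedirs(cache_dir, exist_ok=True)\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self._dataset)\n",
        "\n",
        "  def __getitem__(self, index):\n",
        "    path = os.path.join(self._cache_dir, '{}.pkl'.format(index))\n",
        "    if os.path.exists(path):\n",
        "      # Cache hit: load the processed sample from storage.\n",
        "      with open(path, 'rb') as f:\n",
        "        return pickle.load(f)\n",
        "    # Cache miss: compute the sample once, then store it.\n",
        "    sample = self._dataset[index]\n",
        "    with open(path, 'wb') as f:\n",
        "      pickle.dump(sample, f)\n",
        "    return sample\n",
        "\n",
        "\n",
        "class SquaresDataset:\n",
        "  # Hypothetical source dataset that counts how often it has to compute.\n",
        "\n",
        "  def __init__(self, size):\n",
        "    self.size = size\n",
        "    self.compute_calls = 0\n",
        "\n",
        "  def __len__(self):\n",
        "    return self.size\n",
        "\n",
        "  def __getitem__(self, index):\n",
        "    self.compute_calls += 1\n",
        "    return index * index\n",
        "\n",
        "\n",
        "source = SquaresDataset(4)\n",
        "cached = SimpleCachedDataset(source, tempfile.mkdtemp())\n",
        "first_pass = [cached[i] for i in range(len(cached))]   # fills the cache\n",
        "second_pass = [cached[i] for i in range(len(cached))]  # served from the cache\n",
        "print(first_pass == second_pass, source.compute_calls)"
      ],
      "execution_count": null,
      "outputs": []
    },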
    {
      "cell_type": "code",
      "metadata": {
        "id": "sPJVqAKyml5W",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "VERSION = \"nightly\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "!python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "07WeJ77alpkR",
        "colab_type": "text"
      },
      "source": [
        "A *CachedDataset* can be used transparently, by wrapping an existing *PyTorch* *Dataset*:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KEb7ZOKRlovh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.utils.cached_dataset as xcd\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "def _mp_fn(index):\n",
        "  train_dataset = datasets.MNIST(\n",
        "      '/tmp/mnist-data',\n",
        "      train=True,\n",
        "      download=True,\n",
        "      transform=transforms.Compose(\n",
        "              [transforms.ToTensor(),\n",
        "               transforms.Normalize((0.1307,), (0.3081,))]))\n",
        "  train_dataset = xcd.CachedDataset(train_dataset, '/tmp/cached-mnist-data')\n",
        "  # The normal model code follows here ...\n",
        "\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(), start_method='fork', nprocs=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZNDoQHstfsSu",
        "colab_type": "text"
      },
      "source": [
        "The following example populates a *CachedDataset* whose cache folder can then be exported to other locations:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zQ2_OcQxMEI8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.utils.cached_dataset as xcd\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "def _mp_fn(index):\n",
        "  train_dataset = datasets.MNIST(\n",
        "      '/tmp/mnist-data',\n",
        "      train=True,\n",
        "      download=True,\n",
        "      transform=transforms.Compose(\n",
        "              [transforms.ToTensor(),\n",
        "               transforms.Normalize((0.1307,), (0.3081,))]))\n",
        "  cached_dataset = xcd.CachedDataset(train_dataset, '/tmp/cached-mnist-data')\n",
        "  print('Warming up ...')\n",
        "  cached_dataset.warmup()\n",
        "  print('Done!')\n",
        "\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(), start_method='fork', nprocs=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xE4KFwWak5Sp",
        "colab_type": "text"
      },
      "source": [
        "The *CachedDataset* generated in **/tmp/cached-mnist-data** can then be packed and used in other setups.\n",
        "\n",
        "A *CachedDataset* uses PyTorch serialization to save samples, so it is portable to any machine where PyTorch is installed.\n",
        "\n",
        "Simply use *tar* to pack it:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kL9sxjZako-o",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!tar czf cached-mnist.tar.gz /tmp/cached-mnist-data/"
      ],
      "execution_count": null,
      "outputs": []
    },
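    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The same pack-and-unpack round trip can also be sketched from Python with the standard `tarfile` module, which may be convenient where a shell is not available. The folder and file names below (`cached-data`, `sample-0.bin`) are hypothetical placeholders created in temporary directories, not the real MNIST cache. The folder is stored under a relative name so that it can be extracted into any destination directory:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import tarfile\n",
        "import tempfile\n",
        "\n",
        "# Hypothetical stand-in for a populated cache folder.\n",
        "work_dir = tempfile.mkdtemp()\n",
        "cache_dir = os.path.join(work_dir, 'cached-data')\n",
        "os.makedirs(cache_dir)\n",
        "with open(os.path.join(cache_dir, 'sample-0.bin'), 'wb') as f:\n",
        "  f.write(b'example')\n",
        "\n",
        "# Pack the folder, storing its content under a relative name.\n",
        "archive_path = os.path.join(work_dir, 'cached-data.tar.gz')\n",
        "with tarfile.open(archive_path, 'w:gz') as tar:\n",
        "  tar.add(cache_dir, arcname='cached-data')\n",
        "\n",
        "# Unpack it into a different directory, as another machine would.\n",
        "dest_dir = tempfile.mkdtemp()\n",
        "with tarfile.open(archive_path, 'r:gz') as tar:\n",
        "  tar.extractall(dest_dir)\n",
        "\n",
        "restored = sorted(os.listdir(os.path.join(dest_dir, 'cached-data')))\n",
        "print(restored)"
      ],
      "execution_count": null,
      "outputs": []
    },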
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UInwgBTAmR1p",
        "colab_type": "text"
      },
      "source": [
        "The fully cached *CachedDataset* can then be used on other machines, without even needing to instantiate the original *Dataset* (simply pass *None* as the source *Dataset* object):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-X8u5qauoXGk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.utils.cached_dataset as xcd\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "\n",
        "def _mp_fn(index):\n",
        "  train_dataset = xcd.CachedDataset(None, '/tmp/cached-mnist-data')\n",
        "  # The normal model code follows here ...\n",
        "\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(), start_method='fork', nprocs=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L2g56bHZoum1",
        "colab_type": "text"
      },
      "source": [
        "The XLA *CachedDataset* implementation natively supports GCS (Google Cloud Storage) as a storage destination/source.\n",
        "\n",
        "Simply prefix the paths with gs:// and make sure the environment is properly set up to access GCS:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S2aEnnjkpKY0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%env GOOGLE_APPLICATION_CREDENTIALS=/PATH/TO/CREDENTIALS_JSON"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ur7GGOsdpQNI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.utils.cached_dataset as xcd\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "def _mp_fn(index):\n",
        "  train_dataset = datasets.MNIST(\n",
        "      '/tmp/mnist-data',\n",
        "      train=True,\n",
        "      download=True,\n",
        "      transform=transforms.Compose(\n",
        "              [transforms.ToTensor(),\n",
        "               transforms.Normalize((0.1307,), (0.3081,))]))\n",
        "  train_dataset = xcd.CachedDataset(train_dataset, 'gs://my_bucket/cached-mnist-data')\n",
        "  # The normal model code follows here ...\n",
        "\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(), start_method='fork', nprocs=1)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}